# -*- coding: utf-8 -*-
# Copyright Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.

from mindspore.ops import operations as P
from ascend_utils.mindspore.quant.ptq_quant.deploy.deploy_quant import DeployQuant


class LinearDeployQuant(DeployQuant):
    """
    convert _Linear to quantize op, that can be exported to air model.
    """
    def __init__(self, simulated_quant_cell):
        super().__init__(simulated_quant_cell)
        self.in_channels = simulated_quant_cell.compute_cell.in_channels
        self.out_channels = simulated_quant_cell.compute_cell.out_channels
        self.activation = simulated_quant_cell.compute_cell.activation
        self.activation_flag = simulated_quant_cell.compute_cell.activation_flag
        self.op_core = P.MatMul()

    def construct(self, input_x):
        input_dtype = self.dtype(input_x)
        quant_x = self.quant(input_x)
        quant_x = self.reshape(quant_x, (-1, self.in_channels))
        quant_y = self.op_core(quant_x, self.weight)
        if self.has_bias:
            quant_y = self.bias_add(quant_y, self.bias)
        dequant_y = self.dequant(quant_y, self.fused_deq_scale)
        output = self.cast(dequant_y, input_dtype)
        if self.activation_flag:
            output = self.activation(output)
        out_shape = self.shape_op(input_x)[:-1] + (self.out_channels,)
        output = self.reshape(output, out_shape)
        return output